from mindformers import MindFormerConfig
from mindformers.models.vit.vit_new import ViTNewModel
import mindspore as ms
import numpy as np
import random
from mindformers import AutoConfig
# Smoke test: build a ViTNewModel from an overridden "vit_base_p16" config,
# load weights from a local checkpoint, run one forward pass on a seeded
# random input and print the result.

# ms.PYNATIVE_MODE is the named constant for the literal mode=1 used before.
ms.set_context(mode=ms.PYNATIVE_MODE)

# Fixed seed so the random input is identical between runs.
np.random.seed(1)

# NOTE(review): 5-D input, presumably (batch, channels, frames, height, width)
# -- confirm against what ViTNewModel expects.
inputs = np.random.rand(1, 3, 8, 224, 224)

inputs = ms.Tensor(inputs, ms.float32)

# NOTE(review): yaml_path is not read anywhere in the visible script; the
# config below comes from AutoConfig instead. Kept in case later code uses it.
yaml_path = "/home/zhangyouwen/suite/mobile_commucation/mindformers/configs/vit/run_vit_base_p16_224_100ep.yaml"

config = AutoConfig.from_pretrained("vit_base_p16")

# Override the stock config before the model is constructed.
# NOTE(review): these values presumably match the architecture stored in the
# checkpoint below -- verify against the checkpoint's origin.
config.depth = 23
config.embed_dim = 1024
config.num_heads = 16
config.use_mean_pooling = False

vit_model = ViTNewModel(config)

ckpt_path = "/home/zhangyouwen/suite/mobile_commucation/videochat2_7b_stage2.ckpt"
state_dict = ms.load_checkpoint(ckpt_path)

# Keep the result instead of discarding it: parameters that are missing from
# the checkpoint silently stay at their initial values, which would make the
# printed output meaningless without any visible error.
load_result = ms.load_param_into_net(vit_model, state_dict)

# Newer MindSpore releases return (param_not_load, ckpt_not_load); older ones
# return only the list of network parameters that were not loaded.
if isinstance(load_result, tuple) and len(load_result) == 2:
    param_not_load, ckpt_not_load = load_result
else:
    param_not_load, ckpt_not_load = (load_result or []), []

if param_not_load:
    print(
        f"WARNING: {len(param_not_load)} model parameter(s) were not loaded "
        f"from {ckpt_path} and keep their initial values: {list(param_not_load)}"
    )
if ckpt_not_load:
    # Only the count is printed; the checkpoint may hold many unrelated entries.
    print(
        f"NOTE: {len(ckpt_not_load)} checkpoint entr(y/ies) in {ckpt_path} "
        "did not match any model parameter"
    )

output = vit_model(inputs)

print(output)
